Nested numerical solving problems

Scientific knowledge often takes the form of specific relationships expressed by systems of equations, for example the differential equations for microbial growth used in the example below.

If there is an analytic solution to the equation system, we can just include the solution in our statistical model like any other form of structural knowledge: easy! However, often we want to solve equations that are hard or impossible to solve analytically, but can be solved approximately using numerical methods.

This is tricky in the context of Hamiltonian Monte Carlo for two reasons:

  1. Computation: HMC requires many evaluations of the log probability density function and its gradients.

     Important: at every evaluation, the sampler needs to solve the embedded equation system and find the gradients of the solution with respect to all model parameters.

  2. Extra source of error: how accurate does the numerical approximation need to be?
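To get a feel for the second point, the sketch below solves a toy growth equation with a known exact solution at several tolerances. It uses SciPy's `solve_ivp`, which is an assumption for illustration only: the model on this page uses Stan's own solver. Tighter tolerances shrink the error but cost more right-hand-side evaluations, and that cost feeds straight back into the first point.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Toy problem with a known answer: dC/dt = r * C, so C(t) = C0 * exp(r * t)
r, C0, t_end = 0.5, 0.1, 10.0


def rhs(t, y):
    return r * y


def solve_with_tolerance(rtol, atol):
    """Return (absolute error at t_end, number of RHS evaluations) for SciPy's default RK45."""
    sol = solve_ivp(rhs, (0.0, t_end), [C0], rtol=rtol, atol=atol)
    exact = C0 * np.exp(r * t_end)
    return abs(sol.y[0, -1] - exact), sol.nfev


results = {}
for rtol in [1e-3, 1e-6, 1e-9]:
    error, nfev = solve_with_tolerance(rtol, atol=rtol * 1e-3)
    results[rtol] = (error, nfev)
    print(f"rtol={rtol:.0e}: error {error:.2e} using {nfev} RHS evaluations")
```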

Reading: Timonen et al. (2022).

Example

We have some tubes containing a substrate \(S\) and some biomass \(C\) that we think approximately follow the Monod equation for microbial growth:

\[\begin{align*} \frac{dC}{dt} &= \frac{\mu_{max}\cdot S(t)}{K_{S} + S(t)}\cdot C(t) \\ \frac{dS}{dt} &= -\gamma \cdot \frac{\mu_{max}\cdot S(t)}{K_{S} + S(t)} \cdot C(t) \end{align*}\]
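For intuition, here is a minimal sketch of solving these equations numerically in Python. It assumes SciPy's `solve_ivp` with its BDF method as a stand-in for Stan's `ode_bdf_tol`, which calls the CVODES solver from SUNDIALS. The parameter values are illustrative. Because \(\frac{dS}{dt} = -\gamma \frac{dC}{dt}\), the quantity \(S + \gamma C\) stays constant over time, and that gives a cheap check on the numerical solution.

```python
import numpy as np
from scipy.integrate import solve_ivp


def monod_rhs(t, y, mu_max, ks, gamma):
    """Right-hand side of the Monod system; y = [C, S] (biomass, substrate)."""
    C, S = y
    mu = mu_max * S / (ks + S)  # specific growth rate at the current substrate level
    return [mu * C, -gamma * mu * C]


# Illustrative values, roughly the size of the true parameters used later on
mu_max, ks, gamma = 0.18, 0.27, 0.55
y0 = [0.12, 1.2]  # [C(0), S(0)]
t_eval = np.linspace(0.01, 15, 20)

sol = solve_ivp(
    monod_rhs, (0.0, 15.0), y0, t_eval=t_eval, args=(mu_max, ks, gamma),
    method="BDF", rtol=1e-7, atol=1e-7,
)
C_hat, S_hat = sol.y

# S + gamma * C should be (numerically) constant
print(np.ptp(S_hat + gamma * C_hat))
```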

We measured \(C\) and \(S\) at several timepoints in some experiments, and we want to estimate \(\mu_{max}\), \(K_{S}\) and \(\gamma\) for the different strains in the tubes.

You can read more about the Monod equation in Allen and Waclaw (2019).

What we know

\(\mu_{max}, K_S, \gamma, S, C\) are non-negative.

\(S(0)\) and \(C(0)\) vary a little by tube.

\(\mu_{max}, K_S, \gamma\) vary by strain.

Measurement noise is roughly proportional to the measured quantity.

Statistical model

We use two regression models to describe the measurements:

\[\begin{align*} y_C &\sim LN(\ln{\hat{C}}, \sigma_{C}) \\ y_S &\sim LN(\ln{\hat{S}}, \sigma_{S}) \end{align*}\]
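Why lognormal? Noise that is roughly proportional to the measured quantity means a roughly constant coefficient of variation (sd/mean). The quick NumPy check below uses simulated values, not data from this example. It shows that with \(y \sim LN(\ln{\hat{y}}, \sigma)\) the standard deviation grows with \(\hat{y}\), while sd/mean stays at \(\sqrt{e^{\sigma^2} - 1} \approx \sigma\) for small \(\sigma\).

```python
import numpy as np

demo_rng = np.random.default_rng(0)
sigma = 0.1
theoretical_cv = np.sqrt(np.exp(sigma**2) - 1)  # coefficient of variation of a lognormal

summary = []
for true_value in [0.1, 1.0, 10.0]:
    # y ~ LN(ln(true_value), sigma): multiplicative noise around true_value
    y = demo_rng.lognormal(mean=np.log(true_value), sigma=sigma, size=200_000)
    summary.append((true_value, y.std(), y.std() / y.mean()))
    print(f"true value {true_value:5}: sd {y.std():.4f}, sd/mean {y.std() / y.mean():.4f}")

print(f"theoretical sd/mean: {theoretical_cv:.4f}")
```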

To capture the variation in the kinetic parameters between strains, we add a hierarchical regression model:

\[\begin{align*} \ln{\mu_{max}} &\sim N(a_{\mu_{max}}, \tau_{\mu_{max}}) \\ \ln{\gamma} &\sim N(a_{\gamma}, \tau_{\gamma}) \\ \ln{K_S} &\sim N(a_{K_S}, \tau_{K_S}) \end{align*}\]
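The Stan program and the code that generates the true parameter values write this hierarchy in non-centered form: \(\ln{\mu_{max}} = a_{\mu_{max}} + \tau_{\mu_{max}} \cdot z\) with \(z \sim N(0, 1)\). This describes the same distribution, but it usually gives HMC an easier geometry to explore when the group-level scale \(\tau\) is not well pinned down by the data. Here is a quick NumPy check that the two forms agree. The values of `a` and `tau` are illustrative.

```python
import numpy as np

demo_rng = np.random.default_rng(1)
a, tau, n = -1.7, 0.2, 500_000  # location and scale on the log scale

# centered: draw ln(mu_max) directly from N(a, tau)
ln_mu_centered = demo_rng.normal(a, tau, size=n)

# non-centered: draw a standard normal z, then shift and scale it
z = demo_rng.standard_normal(size=n)
ln_mu_noncentered = a + tau * z

print(ln_mu_centered.mean(), ln_mu_noncentered.mean())
print(ln_mu_centered.std(), ln_mu_noncentered.std())
```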

To get the true abundances implied by a given set of parameters, we put an ODE solver in the model:

\[ \hat{C}(t), \hat{S}(t) = \text{solve-monod-equation}(t, C_0, S_0, \mu_{max}, \gamma, K_S) \]

Imports

import itertools

import arviz as az
import cmdstanpy
import pandas as pd
import numpy as np

from matplotlib import pyplot as plt

Specify true parameters

In order to avoid doing too much annoying handling of strings we assume that all the parts of the problem have meaningful 1-indexed integer labels: for example, species 1 is biomass.

This code specifies the dimensions of our problem.

N_strain = 4
N_tube = 16
N_timepoint = 20
duration = 15
strains = [i+1 for i in range(N_strain)]
tubes = [i+1 for i in range(N_tube)]
species = [1, 2]
measurement_timepoint_ixs = [4, 7, 12, 15, 17]
timepoints = pd.Series(
    np.linspace(0.01, duration, N_timepoint),
    name="time",
    index=range(1, N_timepoint+1)
)
SEED = 12345
rng = np.random.default_rng(seed=SEED)

This code defines some true values for the parameters - we will use these to generate fake data.

true_param_values = {
    "a_mu_max": -1.7,
    "a_ks": -1.3,
    "a_gamma": -0.6,
    "t_mu_max": 0.2,
    "t_ks": 0.3,
    "t_gamma": 0.13,
    # random true values come from the seeded generator `rng` so that the
    # fake data are reproducible
    "species_zero": [
        [
            np.exp(rng.normal(-2.1, 0.05)),
            np.exp(rng.normal(0.2, 0.05))
        ] for _ in range(N_tube)
    ],
    "sigma_y": [0.08, 0.1],
    "ln_mu_max_z": rng.normal(0, 1, size=N_strain).tolist(),
    "ln_ks_z": rng.normal(0, 1, size=N_strain).tolist(),
    "ln_gamma_z": rng.normal(0, 1, size=N_strain).tolist(),
}
# non-centered parameterisation: ln(param) = a + t * z, so param = exp(a + t * z)
for var in ["mu_max", "ks", "gamma"]:
    true_param_values[var] = np.exp(
        true_param_values[f"a_{var}"]
        + true_param_values[f"t_{var}"] * np.array(true_param_values[f"ln_{var}_z"])
    ).tolist()

A bit of data transformation

This code does some handy transformations on the data using pandas, giving us a table of information about the measurements.

tube_to_strain = pd.Series(
    [
        (i % N_strain) + 1 for i in range(N_tube)  # % operator finds remainder
    ], index=tubes, name="strain"
)
measurements = (
    pd.DataFrame(
        itertools.product(tubes, measurement_timepoint_ixs, species),
        columns=["tube", "timepoint", "species"],
        index=range(1, len(tubes) * len(measurement_timepoint_ixs) * len(species) + 1)
    )
    .join(tube_to_strain, on="tube")
    .join(timepoints, on="timepoint")
)
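To see the shape of the resulting long-format table, here is the same recipe applied to a hypothetical miniature problem. The sizes, labels and times are made up. The result has one row per (tube, timepoint, species) combination, with the strain and measurement time looked up by label.

```python
import itertools

import pandas as pd

# Hypothetical miniature problem: 2 tubes (one per strain), 3 timepoints of which
# 2 are measured, 2 species
mini_tubes = [1, 2]
mini_tube_to_strain = pd.Series([1, 2], index=mini_tubes, name="strain")
mini_timepoints = pd.Series([0.5, 1.0, 1.5], name="time", index=[1, 2, 3])
mini_measured_ixs = [2, 3]
mini_species = [1, 2]

mini_measurements = (
    pd.DataFrame(
        itertools.product(mini_tubes, mini_measured_ixs, mini_species),
        columns=["tube", "timepoint", "species"],
    )
    .join(mini_tube_to_strain, on="tube")  # look up each row's strain by tube label
    .join(mini_timepoints, on="timepoint")  # look up each row's time by timepoint label
)
print(mini_measurements)
```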

Generating a Stan input dictionary

This code puts the data in the correct format for cmdstanpy.

stan_input_structure = {
    "N_measurement": len(measurements),
    "N_timepoint": N_timepoint,
    "N_tube": N_tube,
    "N_strain": N_strain,
    "tube": measurements["tube"].values.tolist(),
    "measurement_timepoint": measurements["timepoint"].values.tolist(),
    "measured_species": measurements["species"].values.tolist(),
    "strain": tube_to_strain.values.tolist(),
    "timepoint_time": timepoints.values.tolist(),
}

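This code defines prior distributions for the model’s parameters. Each entry is a [location, scale] pair. For parameters that can be negative, these are the mean and standard deviation of a normal prior. For non-negative parameters, they are the parameters of a lognormal prior, i.e. the mean and standard deviation on the log scale.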

priors = {
    # parameters that can be negative:
    "prior_a_mu_max": [-1.8, 0.2],
    "prior_a_ks": [-1.3, 0.1],
    "prior_a_gamma": [-0.5, 0.1],
    # parameters that are non-negative:
    "prior_t_mu_max": [-1.4, 0.1],
    "prior_t_ks": [-1.2, 0.1],
    "prior_t_gamma": [-2, 0.1],
    "prior_species_zero": [[[-2.1, 0.1], [0.2, 0.1]]] * N_tube,
    "prior_sigma_y": [[-2.3, 0.15], [-2.3, 0.15]],
}
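Because the non-negative parameters get lognormal priors, their [location, scale] pairs are on the log scale and are easier to judge after converting back. This sketch prints each prior's median, which is \(e^{location}\), and its central 95% interval. It uses SciPy, an extra dependency that this page does not otherwise use.

```python
import numpy as np
from scipy import stats

# [location, scale] pairs on the log scale, copied from the priors dictionary above
lognormal_priors = {
    "t_mu_max": [-1.4, 0.1],
    "t_ks": [-1.2, 0.1],
    "t_gamma": [-2, 0.1],
    "sigma_y": [-2.3, 0.15],
}

prior_summaries = {}
for name, (mu, sigma) in lognormal_priors.items():
    # SciPy parameterises LN(mu, sigma) with shape s=sigma and scale=exp(mu)
    dist = stats.lognorm(s=sigma, scale=np.exp(mu))
    lower, median, upper = dist.ppf([0.025, 0.5, 0.975])
    prior_summaries[name] = (lower, median, upper)
    print(f"{name}: median {median:.3f}, 95% interval [{lower:.3f}, {upper:.3f}]")
```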

The next bit of code lets us configure Stan’s interface to the Sundials ODE solver.

ode_solver_configuration = {
    "abs_tol": 1e-7,
    "rel_tol": 1e-7,
    "max_num_steps": int(1e7)
}
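As a rough guide to what the two tolerances mean: CVODES is the SUNDIALS solver that `ode_bdf_tol` calls (hence the `CVode` messages in the sampler output). It weights the estimated local error in each state component by \(1 / (\text{rel\_tol} \cdot |y_i| + \text{abs\_tol})\), and it accepts a step only when the weighted root-mean-square error is at most about 1. This small sketch just evaluates that error scale for a few abundance values. With both tolerances at 1e-7, the absolute tolerance dominates whenever a state is below 1.

```python
import numpy as np


def error_allowance(y, rel_tol, abs_tol):
    """Error scale rel_tol * |y| + abs_tol used to weight each state component."""
    return rel_tol * np.abs(y) + abs_tol


# Abundances spanning the range seen in this problem (substrate can approach 0)
typical_abundances = np.array([0.01, 0.1, 1.0, 2.0])
allowance = error_allowance(typical_abundances, rel_tol=1e-7, abs_tol=1e-7)
print(allowance)
```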

Now we can put all the inputs together:

stan_input_common = stan_input_structure | priors | ode_solver_configuration
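The `|` operator merges dictionaries (Python 3.9 and later) into a new dictionary. When both sides share a key, the right-hand value wins, and that is what lets later cells reuse `stan_input_common` while adding or overriding `y` and `likelihood`. The values below are made up.

```python
base = {"abs_tol": 1e-7, "likelihood": 0}
override = {"likelihood": 1, "y": [0.2, 1.1]}

merged = base | override  # new dict; base and override are unchanged
print(merged)  # {'abs_tol': 1e-07, 'likelihood': 1, 'y': [0.2, 1.1]}
```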

Load the model

This code loads the Stan program at monod.stan as a CmdStanModel object and compiles it using cmdstan’s compiler.

model = cmdstanpy.CmdStanModel(stan_file="../src/stan/monod.stan")
print(model.code())
functions {
  real get_mu_at_t(real mu_max, real ks, real S_at_t) {
    return (mu_max * S_at_t) / (ks + S_at_t);
  }
  vector ddt(real t, vector species, real mu_max, real ks, real gamma) {
    real mu_at_t = get_mu_at_t(mu_max, ks, species[2]);
    vector[2] out;
    out[1] = mu_at_t * species[1];
    out[2] = -gamma * mu_at_t * species[1];
    return out;
  }
}
data {
  int<lower=1> N_measurement;
  int<lower=1> N_timepoint;
  int<lower=1> N_tube;
  int<lower=1> N_strain;
  array[N_measurement] int<lower=1, upper=N_tube> tube;
  array[N_measurement] int<lower=1, upper=N_timepoint> measurement_timepoint;
  array[N_measurement] int<lower=1, upper=2> measured_species;
  vector<lower=0>[N_measurement] y;
  array[N_tube] int<lower=1, upper=N_strain> strain;
  array[N_timepoint] real<lower=0> timepoint_time;
  array[N_tube, 2] vector[2] prior_species_zero;
  array[2] vector[2] prior_sigma_y;
  vector[2] prior_a_mu_max;
  vector[2] prior_a_ks;
  vector[2] prior_a_gamma;
  vector[2] prior_t_mu_max;
  vector[2] prior_t_gamma;
  vector[2] prior_t_ks;
  real<lower=0> abs_tol;
  real<lower=0> rel_tol;
  int<lower=1> max_num_steps;
  int<lower=0, upper=1> likelihood;
}
parameters {
  vector[N_strain] ln_mu_max_z;
  vector[N_strain] ln_ks_z;
  vector[N_strain] ln_gamma_z;
  real a_mu_max;
  real a_ks;
  real a_gamma;
  real<lower=0> t_mu_max;
  real<lower=0> t_ks;
  real<lower=0> t_gamma;
  array[N_tube] vector<lower=0>[2] species_zero;
  vector<lower=0>[2] sigma_y;
}
transformed parameters {
  vector[N_strain] mu_max = exp(a_mu_max + ln_mu_max_z * t_mu_max);
  vector[N_strain] ks = exp(a_ks + ln_ks_z * t_ks);
  vector[N_strain] gamma = exp(a_gamma + ln_gamma_z * t_gamma);
  array[N_tube, N_timepoint] vector[2] abundance;
  for (tube_t in 1 : N_tube) {
    abundance[tube_t] = ode_bdf_tol(ddt, species_zero[tube_t], 0,
                                    timepoint_time,
                                    abs_tol, rel_tol, max_num_steps,
                                    mu_max[strain[tube_t]],
                                    ks[strain[tube_t]], gamma[strain[tube_t]]);
  }
}
model {
  // priors
  ln_mu_max_z ~ std_normal();
  ln_ks_z ~ std_normal();
  ln_gamma_z ~ std_normal();
  a_mu_max ~ normal(prior_a_mu_max[1], prior_a_mu_max[2]);
  a_ks ~ normal(prior_a_ks[1], prior_a_ks[2]);
  a_gamma ~ normal(prior_a_gamma[1], prior_a_gamma[2]);
  t_mu_max ~ lognormal(prior_t_mu_max[1], prior_t_mu_max[2]);
  t_ks ~ lognormal(prior_t_ks[1], prior_t_ks[2]);
  t_gamma ~ lognormal(prior_t_gamma[1], prior_t_gamma[2]);
  for (s in 1 : 2) {
    sigma_y[s] ~ lognormal(prior_sigma_y[s, 1], prior_sigma_y[s, 2]);
    for (t in 1 : N_tube){
      species_zero[t, s] ~ lognormal(prior_species_zero[t, s, 1],
                                     prior_species_zero[t, s, 2]);
    }
  }
  // likelihood
  if (likelihood) {
    for (m in 1 : N_measurement) {
      real yhat = abundance[tube[m], measurement_timepoint[m], measured_species[m]];
      y[m] ~ lognormal(log(yhat), sigma_y[measured_species[m]]);
    }
  }
}
generated quantities {
  vector[N_measurement] yrep;
  vector[N_measurement] llik;
  for (m in 1 : N_measurement){
    real yhat = abundance[tube[m], measurement_timepoint[m], measured_species[m]];
    yrep[m] = lognormal_rng(log(yhat), sigma_y[measured_species[m]]);
    llik[m] = lognormal_lpdf(y[m] | log(yhat), sigma_y[measured_species[m]]);
  }
}

Sample in fixed param mode to generate fake data

stan_input_true = stan_input_common | {
    "y": np.ones(len(measurements)).tolist(),  # dummy values as we don't need measurements yet
    "likelihood": 0                            # we don't need to evaluate the likelihood
}
coords = {
    "strain": strains,
    "tube": tubes,
    "species": species,
    "timepoint": timepoints.index.values,
    "measurement": measurements.index.values
}
dims = {
    "abundance": ["tube", "timepoint", "species"],
    "mu_max": ["strain"],
    "ks": ["strain"],
    "gamma": ["strain"],
    "species_zero": ["tube", "species"],
    "y": ["measurement"],
    "yrep": ["measurement"],
    "llik": ["measurement"]
}

mcmc_true = model.sample(
    data=stan_input_true,
    iter_sampling=1,
    fixed_param=True,
    chains=1,
    refresh=1,
    inits=true_param_values,
    seed=SEED,
)
idata_true = az.from_cmdstanpy(
    mcmc_true,
    dims=dims,
    coords=coords,
    posterior_predictive={"y": "yrep"},
    log_likelihood="llik"
)
17:03:49 - cmdstanpy - INFO - CmdStan start processing
17:03:49 - cmdstanpy - INFO - CmdStan done processing.

Look at results

def plot_sim(true_abundance, fake_measurements, species_to_ax):
    f, axes = plt.subplots(1, 2, figsize=[9, 3])

    axes[species_to_ax[1]].set_title("Species 1")
    axes[species_to_ax[2]].set_title("Species 2")
    for ax in axes:
        ax.set_xlabel("Time")
        ax.set_ylabel("Abundance")
    # plot each tube's true trajectory and simulated measurements once,
    # on the panel for its species
    for (tube_i, species_i), df_i in true_abundance.groupby(["tube", "species"]):
        ax = axes[species_to_ax[species_i]]
        fm = df_i.merge(
            fake_measurements.drop("time", axis=1),
            on=["tube", "species", "timepoint"]
        )
        ax.plot(
            df_i.set_index("time")["abundance"], color="black", linewidth=0.5
        )
        ax.scatter(
            fm["time"],
            fm["simulated_measurement"],
            color="r",
            marker="x",
            label="simulated measurement"
        )
    return f, axes

species_to_ax = {1: 0, 2: 1}
true_abundance = (
    idata_true.posterior["abundance"]
    .to_dataframe()
    .droplevel(["chain", "draw"])
    .join(timepoints, on="timepoint")
    .reset_index()
)
fake_measurements = measurements.join(
    idata_true.posterior_predictive["yrep"]
    .to_series()
    .droplevel(["chain", "draw"])
    .rename("simulated_measurement")
).copy()
f, axes = plot_sim(true_abundance, fake_measurements, species_to_ax)

f.savefig("img/monod_simulated_data.png")

Sample in prior mode

stan_input_prior = stan_input_common | {
    "y": fake_measurements["simulated_measurement"],
    "likelihood": 0
}
mcmc_prior = model.sample(
    data=stan_input_prior,
    iter_warmup=100,
    iter_sampling=100,
    chains=1,
    refresh=1,
    save_warmup=True,
    inits=true_param_values,
    seed=SEED,
)
idata_prior = az.from_cmdstanpy(
    mcmc_prior,
    dims=dims,
    coords=coords,
    posterior_predictive={"y": "yrep"},
    log_likelihood="llik"
)
idata_prior
17:03:50 - cmdstanpy - INFO - CmdStan start processing
17:04:36 - cmdstanpy - INFO - CmdStan done processing.
17:04:36 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: CVode(cvodes_mem, t_final, nv_state_, &t_init, CV_NORMAL) failed with error flag -4: 
    Exception: CVode(cvodes_mem, t_final, nv_state_, &t_init, CV_NORMAL) failed with error flag -1: 
    Exception: lognormal_rng: Location parameter is nan, but must be finite! (in 'monod.stan', line 94, column 4 to column 69)
Consider re-running with show_console=True if the above output is unclear!
arviz.InferenceData
    • <xarray.Dataset> Size: 564kB
      Dimensions:            (chain: 1, draw: 100, ln_mu_max_z_dim_0: 4,
                              ln_ks_z_dim_0: 4, ln_gamma_z_dim_0: 4, tube: 16,
                              species: 2, sigma_y_dim_0: 2, strain: 4, timepoint: 20)
      Coordinates:
        * chain              (chain) int64 8B 0
        * draw               (draw) int64 800B 0 1 2 3 4 5 6 ... 93 94 95 96 97 98 99
        * ln_mu_max_z_dim_0  (ln_mu_max_z_dim_0) int64 32B 0 1 2 3
        * ln_ks_z_dim_0      (ln_ks_z_dim_0) int64 32B 0 1 2 3
        * ln_gamma_z_dim_0   (ln_gamma_z_dim_0) int64 32B 0 1 2 3
        * tube               (tube) int64 128B 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
        * species            (species) int64 16B 1 2
        * sigma_y_dim_0      (sigma_y_dim_0) int64 16B 0 1
        * strain             (strain) int64 32B 1 2 3 4
        * timepoint          (timepoint) int64 160B 1 2 3 4 5 6 ... 15 16 17 18 19 20
      Data variables: (12/15)
          ln_mu_max_z        (chain, draw, ln_mu_max_z_dim_0) float64 3kB -0.4693 ....
          ln_ks_z            (chain, draw, ln_ks_z_dim_0) float64 3kB -0.06056 ... ...
          ln_gamma_z         (chain, draw, ln_gamma_z_dim_0) float64 3kB -0.3497 .....
          a_mu_max           (chain, draw) float64 800B -1.465 -2.038 ... -2.008
          a_ks               (chain, draw) float64 800B -1.201 -1.241 ... -1.25 -1.403
          a_gamma            (chain, draw) float64 800B -0.5878 -0.5633 ... -0.5415
          ...                 ...
          species_zero       (chain, draw, tube, species) float64 26kB 0.1346 ... 1...
          sigma_y            (chain, draw, sigma_y_dim_0) float64 2kB 0.1206 ... 0....
          mu_max             (chain, draw, strain) float64 3kB 0.2081 ... 0.1613
          ks                 (chain, draw, strain) float64 3kB 0.2957 ... 0.2889
          gamma              (chain, draw, strain) float64 3kB 0.532 0.5814 ... 0.58
          abundance          (chain, draw, tube, timepoint, species) float64 512kB ...
      Attributes:
          created_at:                 2024-04-29T15:04:36.501119
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 130kB
      Dimensions:      (chain: 1, draw: 100, measurement: 160)
      Coordinates:
        * chain        (chain) int64 8B 0
        * draw         (draw) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
        * measurement  (measurement) int64 1kB 1 2 3 4 5 6 ... 155 156 157 158 159 160
      Data variables:
          yrep         (chain, draw, measurement) float64 128kB 0.2038 ... 0.7842
      Attributes:
          created_at:                 2024-04-29T15:04:36.507111
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 130kB
      Dimensions:      (chain: 1, draw: 100, measurement: 160)
      Coordinates:
        * chain        (chain) int64 8B 0
        * draw         (draw) int64 800B 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
        * measurement  (measurement) int64 1kB 1 2 3 4 5 6 ... 155 156 157 158 159 160
      Data variables:
          llik         (chain, draw, measurement) float64 128kB 2.611 ... -26.84
      Attributes:
          created_at:                 2024-04-29T15:04:36.508049
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 11kB
      Dimensions:          (chain: 1, draw: 200)
      Coordinates:
        * chain            (chain) int64 8B 0
        * draw             (draw) int64 2kB 0 1 2 3 4 5 6 ... 194 195 196 197 198 199
      Data variables:
          lp               (chain, draw) float64 2kB -49.78 -49.78 ... -60.39 -61.47
          acceptance_rate  (chain, draw) float64 2kB 0.9419 0.0 0.0 ... 0.6652 0.674
          step_size        (chain, draw) float64 2kB 0.03125 12.94 ... 0.06485 0.06485
          tree_depth       (chain, draw) int64 2kB 7 0 0 2 8 7 6 6 ... 6 6 6 6 6 6 6 3
          n_steps          (chain, draw) int64 2kB 127 1 1 6 255 255 ... 63 63 63 63 7
          diverging        (chain, draw) bool 200B False True True ... False False
          energy           (chain, draw) float64 2kB 77.45 73.93 71.99 ... 81.05 89.4
      Attributes:
          created_at:                 2024-04-29T15:04:36.505485
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

To see what the priors imply, we can plot a random sample of prior draws of the true abundance on top of the simulated data.

prior_abundances = idata_prior.posterior["abundance"]

n_sample = 20
chains = rng.choice(prior_abundances.coords["chain"].values, n_sample)
draws = rng.choice(prior_abundances.coords["draw"].values, n_sample)
f, axes = plot_sim(true_abundance, fake_measurements, species_to_ax)

for ax, species_i in zip(axes, species):
    for tube_j in tubes:
        for chain, draw in zip(chains, draws):
            timeseries = prior_abundances.sel(chain=chain, draw=draw, tube=tube_j, species=species_i)
            ax.plot(
                timepoints.values, 
                timeseries.values,
                alpha=0.5, color="skyblue", zorder=-1
            )
f.savefig("img/monod_priors.png")
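This plot shows individual prior draws. To summarise them as intervals instead, one option is to take quantiles over all chains and draws. Below is a minimal NumPy sketch that uses a random stand-in array with the same dimension order as `idata_prior.posterior["abundance"]`.

```python
import numpy as np

# Stand-in for idata_prior.posterior["abundance"].values, whose dimensions are
# (chain, draw, tube, timepoint, species); here: 1 chain, 100 draws, 3 tubes,
# 5 timepoints, 2 species
demo_rng = np.random.default_rng(2)
abundance_draws = demo_rng.lognormal(0.0, 0.3, size=(1, 100, 3, 5, 2))

# 5% and 95% quantiles over all chains and draws: a 90% interval for every
# tube, timepoint and species
lower, upper = np.quantile(abundance_draws, [0.05, 0.95], axis=(0, 1))
print(lower.shape)  # (3, 5, 2)
```

With the real draws, something like `ax.fill_between(timepoints.values, lower[tube_j - 1, :, species_i - 1], upper[tube_j - 1, :, species_i - 1], color="skyblue", alpha=0.3)` would shade one tube's band. The `- 1` converts the 1-indexed labels to array positions.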

Sample in posterior mode

stan_input_posterior = stan_input_common | {
    "y": fake_measurements["simulated_measurement"],
    "likelihood": 1
}
mcmc_posterior = model.sample(
    data=stan_input_posterior,
    iter_warmup=300,
    iter_sampling=300,
    chains=4,
    refresh=1,
    inits=true_param_values,
    seed=SEED,
)
idata_posterior = az.from_cmdstanpy(
    mcmc_posterior,
    dims=dims,
    coords=coords,
    posterior_predictive={"y": "yrep"},
    log_likelihood="llik"
)
idata_posterior
17:04:37 - cmdstanpy - INFO - CmdStan start processing
17:10:39 - cmdstanpy - INFO - CmdStan done processing.
17:10:39 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: CVode(cvodes_mem, t_final, nv_state_, &t_init, CV_NORMAL) failed with error flag -4: 
    Exception: CVode(cvodes_mem, t_final, nv_state_, &t_init, CV_NORMAL) failed with error flag -4: 
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
    Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
    Exception: ode_bdf_tol: initial state[1] is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: CVode(cvodes_mem, t_final, nv_state_, &t_init, CV_NORMAL) failed with error flag -4: 
    Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
    Exception: ode_bdf_tol: initial state[1] is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
Exception: CVode(cvodes_mem, t_final, nv_state_, &t_init, CV_NORMAL) failed with error flag -4: 
    Exception: CVode(cvodes_mem, t_final, nv_state_, &t_init, CV_NORMAL) failed with error flag -4: 
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
    Exception: ode_bdf_tol: initial state[1] is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: ode_bdf_tol: ode parameters and data is inf, but must be finite! (in 'monod.stan', line 56, column 4 to line 60, column 79)
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
    Exception: lognormal_lpdf: Location parameter is nan, but must be finite! (in 'monod.stan', line 85, column 6 to column 64)
Consider re-running with show_console=True if the above output is unclear!
arviz.InferenceData
    • <xarray.Dataset> Size: 7MB
      Dimensions:            (chain: 4, draw: 300, ln_mu_max_z_dim_0: 4,
                              ln_ks_z_dim_0: 4, ln_gamma_z_dim_0: 4, tube: 16,
                              species: 2, sigma_y_dim_0: 2, strain: 4, timepoint: 20)
      Coordinates:
        * chain              (chain) int64 32B 0 1 2 3
        * draw               (draw) int64 2kB 0 1 2 3 4 5 ... 294 295 296 297 298 299
        * ln_mu_max_z_dim_0  (ln_mu_max_z_dim_0) int64 32B 0 1 2 3
        * ln_ks_z_dim_0      (ln_ks_z_dim_0) int64 32B 0 1 2 3
        * ln_gamma_z_dim_0   (ln_gamma_z_dim_0) int64 32B 0 1 2 3
        * tube               (tube) int64 128B 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
        * species            (species) int64 16B 1 2
        * sigma_y_dim_0      (sigma_y_dim_0) int64 16B 0 1
        * strain             (strain) int64 32B 1 2 3 4
        * timepoint          (timepoint) int64 160B 1 2 3 4 5 6 ... 15 16 17 18 19 20
      Data variables: (12/15)
          ln_mu_max_z        (chain, draw, ln_mu_max_z_dim_0) float64 38kB -0.986 ....
          ln_ks_z            (chain, draw, ln_ks_z_dim_0) float64 38kB -0.3128 ... ...
          ln_gamma_z         (chain, draw, ln_gamma_z_dim_0) float64 38kB 1.868 ......
          a_mu_max           (chain, draw) float64 10kB -1.538 -1.456 ... -1.439
          a_ks               (chain, draw) float64 10kB -1.294 -1.327 ... -1.217
          a_gamma            (chain, draw) float64 10kB -0.6778 -0.6196 ... -0.5872
          ...                 ...
          species_zero       (chain, draw, tube, species) float64 307kB 0.1234 ... ...
          sigma_y            (chain, draw, sigma_y_dim_0) float64 19kB 0.08303 ... ...
          mu_max             (chain, draw, strain) float64 38kB 0.1683 ... 0.239
          ks                 (chain, draw, strain) float64 38kB 0.2512 ... 0.2017
          gamma              (chain, draw, strain) float64 38kB 0.6667 ... 0.5471
          abundance          (chain, draw, tube, timepoint, species) float64 6MB 0....
      Attributes:
          created_at:                 2024-04-29T15:10:39.435568
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 2MB
      Dimensions:      (chain: 4, draw: 300, measurement: 160)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 2kB 0 1 2 3 4 5 6 ... 293 294 295 296 297 298 299
        * measurement  (measurement) int64 1kB 1 2 3 4 5 6 ... 155 156 157 158 159 160
      Data variables:
          yrep         (chain, draw, measurement) float64 2MB 0.165 1.264 ... 0.2912
      Attributes:
          created_at:                 2024-04-29T15:10:39.442620
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 2MB
      Dimensions:      (chain: 4, draw: 300, measurement: 160)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 2kB 0 1 2 3 4 5 6 ... 293 294 295 296 297 298 299
        * measurement  (measurement) int64 1kB 1 2 3 4 5 6 ... 155 156 157 158 159 160
      Data variables:
          llik         (chain, draw, measurement) float64 2MB 3.021 -0.1452 ... 0.8522
      Attributes:
          created_at:                 2024-04-29T15:10:39.443956
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 61kB
      Dimensions:          (chain: 4, draw: 300)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 2kB 0 1 2 3 4 5 6 ... 294 295 296 297 298 299
      Data variables:
          lp               (chain, draw) float64 10kB 99.26 93.29 ... 97.05 96.51
          acceptance_rate  (chain, draw) float64 10kB 0.9906 0.9414 ... 0.9219 0.9104
          step_size        (chain, draw) float64 10kB 0.02773 0.02773 ... 0.03921
          tree_depth       (chain, draw) int64 10kB 7 7 7 7 7 7 7 7 ... 7 7 7 6 6 7 7
          n_steps          (chain, draw) int64 10kB 127 127 255 127 ... 63 63 127 127
          diverging        (chain, draw) bool 1kB False False False ... False False
          energy           (chain, draw) float64 10kB -72.0 -68.86 ... -72.84 -68.01
      Attributes:
          created_at:                 2024-04-29T15:10:39.440206
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

Diagnostics: is the posterior ok?

First check the sample_stats group to see whether there were any divergent transitions and whether lp (the log density) converged.

az.summary(idata_posterior.sample_stats)
/Users/tedgro/repos/biosustain/bayesian_statistics_for_computational_biology/.venv/lib/python3.12/site-packages/arviz/stats/diagnostics.py:592: RuntimeWarning: invalid value encountered in scalar divide
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
                     mean      sd   hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail         r_hat
lp                 99.182   5.730   88.560  109.161      0.288    0.204     413.0     634.0  1.010000e+00
acceptance_rate     0.936   0.079    0.783    1.000      0.002    0.002    1308.0    1111.0  1.010000e+00
step_size           0.034   0.004    0.028    0.039      0.002    0.002       4.0       4.0  5.859337e+15
tree_depth          6.966   0.211    7.000    7.000      0.022    0.016      91.0     947.0  1.050000e+00
n_steps           131.213  27.896   63.000  127.000      4.238    3.017      49.0      53.0  1.070000e+00
diverging           0.000   0.000    0.000    0.000      0.000    0.000    1200.0    1200.0           NaN
energy            -72.876   7.834  -87.030  -58.601      0.386    0.274     413.0     730.0  1.010000e+00
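Not every row of this table is meaningful. After warmup the step size is fixed within each chain but differs between chains, so its r_hat is huge by construction. Diverging is all zeros, which gives the NaN r_hat (and most likely the RuntimeWarning above). Counting the warning signs directly can be clearer. Here is a minimal sketch. The helper name is ours, not an ArviZ function, and 10 is CmdStan's default maximum tree depth.

```python
import numpy as np


def count_warning_signs(diverging, tree_depth, max_treedepth=10):
    """Count divergent transitions and draws that hit the tree-depth limit.

    Both inputs are (chain, draw) arrays of sampler statistics.
    """
    diverging = np.asarray(diverging, dtype=bool)
    tree_depth = np.asarray(tree_depth)
    return {
        "n_draws": int(diverging.size),
        "n_divergent": int(diverging.sum()),
        "n_max_treedepth": int((tree_depth >= max_treedepth).sum()),
    }


# Made-up statistics for 2 chains x 3 draws
example = count_warning_signs(
    [[False, True, False], [False, False, False]],
    [[7, 7, 10], [6, 7, 7]],
)
print(example)  # {'n_draws': 6, 'n_divergent': 1, 'n_max_treedepth': 1}
```

For the real run, pass `idata_posterior.sample_stats["diverging"].values` and `idata_posterior.sample_stats["tree_depth"].values`.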

Next, check the parameter-by-parameter summary.

az.summary(idata_posterior)
                        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
ln_mu_max_z[0]        -0.787  0.516  -1.744    0.127      0.020    0.014     661.0     798.0   1.01
ln_mu_max_z[1]         0.402  0.509  -0.533    1.364      0.020    0.014     624.0     773.0   1.00
ln_mu_max_z[2]         0.777  0.513  -0.230    1.665      0.021    0.015     586.0     730.0   1.01
ln_mu_max_z[3]         0.886  0.518  -0.050    1.884      0.021    0.015     607.0     719.0   1.00
ln_ks_z[0]            -0.009  1.006  -1.807    1.813      0.024    0.029    1733.0     951.0   1.00
...                      ...    ...     ...      ...        ...      ...       ...       ...    ...
abundance[16, 18, 2]   0.229  0.027   0.183    0.283      0.001    0.001    1476.0    1055.0   1.00
abundance[16, 19, 1]   1.799  0.087   1.648    1.981      0.002    0.001    1999.0    1080.0   1.00
abundance[16, 19, 2]   0.148  0.026   0.097    0.197      0.001    0.001    1276.0    1061.0   1.00
abundance[16, 20, 1]   1.909  0.096   1.733    2.100      0.002    0.002    1864.0    1087.0   1.00
abundance[16, 20, 2]   0.085  0.024   0.038    0.125      0.001    0.000    1188.0     758.0   1.00

704 rows × 9 columns

Show posterior intervals

posterior_abundances = idata_posterior.posterior["abundance"]

n_sample = 20
chains = rng.choice(posterior_abundances.coords["chain"].values, n_sample)
draws = rng.choice(posterior_abundances.coords["draw"].values, n_sample)
f, axes = plot_sim(true_abundance, fake_measurements, species_to_ax)

for ax, species_i in zip(axes, species):
    for tube_j in tubes:
        for chain, draw in zip(chains, draws):
            timeseries = posterior_abundances.sel(chain=chain, draw=draw, tube=tube_j, species=species_i)
            ax.plot(
                timepoints.values,
                timeseries.values,
                alpha=0.5, color="skyblue", zorder=-1
            )
f.savefig("img/monod_posteriors.png")

Look at the posterior

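The next few cells use ArviZ’s plot_posterior function to plot the marginal posterior distributions of the strain-level parameters \(\gamma\), \(\mu_{max}\) and \(K_S\). The true values used to simulate the data are marked by red vertical lines: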

f, axes = plt.subplots(1, 4, figsize=[10, 4])
axes = az.plot_posterior(
    idata_posterior,
    kind="hist",
    bins=20,
    var_names=["gamma"],
    ax=axes,
    point_estimate=None,
    hdi_prob="hide"
)
for ax, true_value in zip(axes, true_param_values["gamma"]):
    ax.axvline(true_value, color="red")

f, axes = plt.subplots(1, 4, figsize=[10, 4])
axes = az.plot_posterior(
    idata_posterior,
    kind="hist",
    bins=20,
    var_names=["mu_max"],
    ax=axes,
    point_estimate=None,
    hdi_prob="hide"
)
for ax, true_value in zip(axes, true_param_values["mu_max"]):
    ax.axvline(true_value, color="red")

f, axes = plt.subplots(1, 4, figsize=[10, 4])
axes = az.plot_posterior(
    idata_posterior,
    kind="hist",
    bins=20,
    var_names=["ks"],
    ax=axes,
    point_estimate=None,
    hdi_prob="hide"
)
for ax, true_value in zip(axes, true_param_values["ks"]):
    ax.axvline(true_value, color="red")
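The histograms can be backed up by a quick numerical check of whether each true value falls inside a posterior interval. Below is a minimal sketch using central quantile intervals. These are not the HDIs that ArviZ reports; 0.94 just mirrors ArviZ's default interval width. With only four strains per parameter, an occasional miss is not by itself a sign of trouble.

```python
import numpy as np


def central_interval_covers(draws, true_values, prob=0.94):
    """For each column of `draws` (shape: n_draws x n_items), check whether the
    central `prob` interval of the draws contains the matching true value."""
    draws = np.asarray(draws)
    true_values = np.asarray(true_values)
    lower, upper = np.quantile(draws, [(1 - prob) / 2, (1 + prob) / 2], axis=0)
    return (lower <= true_values) & (true_values <= upper)


# Made-up draws for three items; only the last true value lies far outside
demo_rng = np.random.default_rng(3)
demo_draws = demo_rng.normal(0.0, 1.0, size=(4000, 3))
print(central_interval_covers(demo_draws, [0.0, 0.5, 5.0]))  # [ True  True False]
```

For the real posterior, `idata_posterior.posterior["gamma"].stack(sample=("chain", "draw")).values.T` gives a (draw, strain) array that can be passed alongside `true_param_values["gamma"]`.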

References

Allen, Rosalind J, and Bartłomiej Waclaw. 2019. “Bacterial Growth: A Statistical Physicist’s Guide.” Reports on Progress in Physics. Physical Society (Great Britain) 82 (1): 016601. https://doi.org/10.1088/1361-6633/aae546.
Timonen, Juho, Nikolas Siccha, Ben Bales, Harri Lähdesmäki, and Aki Vehtari. 2022. “An Importance Sampling Approach for Reliable and Efficient Inference in Bayesian Ordinary Differential Equation Models.” arXiv. https://doi.org/10.48550/arXiv.2205.09059.